{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using DALI in PyTorch Lightning\n",
    "\n",
    "### Overview\n",
    "\n",
    "This example shows how to use DALI in PyTorch Lightning.\n",
    "\n",
    "Let us grab [a toy example](https://pytorch-lightning.readthedocs.io/en/1.6.1/starter/core_guide.html) showcasing a classification network and see how DALI can accelerate it.\n",
    "\n",
    "The DALI_EXTRA_PATH environment variable should point to a [DALI extra](https://github.com/NVIDIA/DALI_extra) copy. Please make sure that the proper release tag, the one associated with your DALI version, is checked out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torch import nn\n",
    "from pytorch_lightning import Trainer, LightningModule\n",
    "from torch.optim import Adam\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import os\n",
    "\n",
    "# Batch size passed to the DALI pipelines defined further down.\n",
    "BATCH_SIZE = 64\n",
    "\n",
    "# Workaround for https://github.com/pytorch/vision/issues/1938 (error 403 when\n",
    "# downloading the MNIST dataset): install a global urllib opener that sends a\n",
    "# browser-like User-agent header.\n",
    "# The submodule is imported explicitly because a bare `import urllib` does not\n",
    "# load `urllib.request`; the attribute would only resolve if some other\n",
    "# package happened to import it first.\n",
    "import urllib.request\n",
    "\n",
    "opener = urllib.request.build_opener()\n",
    "opener.addheaders = [(\"User-agent\", \"Mozilla/5.0\")]\n",
    "urllib.request.install_opener(opener)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will start by implementing a training class that uses the native data loader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LitMNIST(LightningModule):\n",
    "    \"\"\"Three-layer fully connected MNIST classifier fed by a native PyTorch\n",
    "    ``DataLoader``.\n",
    "\n",
    "    ``process_batch`` is the hook that turns whatever the data loader yields\n",
    "    into the ``(images, labels)`` pair consumed by ``training_step``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # mnist images are (1, 28, 28) (channels, height, width)\n",
    "        self.layer_1 = torch.nn.Linear(28 * 28, 128)\n",
    "        self.layer_2 = torch.nn.Linear(128, 256)\n",
    "        self.layer_3 = torch.nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return per-class log-probabilities for a batch of images.\n",
    "\n",
    "        Every dimension after the first is flattened, so the flattened size\n",
    "        must match the 28 * 28 inputs of ``layer_1``.\n",
    "        \"\"\"\n",
    "        # (b, 1, 28, 28) -> (b, 1*28*28)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = F.relu(self.layer_1(x))\n",
    "        x = F.relu(self.layer_2(x))\n",
    "        x = self.layer_3(x)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "\n",
    "    def process_batch(self, batch):\n",
    "        \"\"\"Return the batch unchanged; subclasses override this to unpack\n",
    "        a different batch layout.\"\"\"\n",
    "        return batch\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the loss for one training batch.\"\"\"\n",
    "        x, y = self.process_batch(batch)\n",
    "        logits = self(x)\n",
    "        return self.cross_entropy_loss(logits, y)\n",
    "\n",
    "    def cross_entropy_loss(self, logits, labels):\n",
    "        \"\"\"Negative log-likelihood of ``labels`` given the log-probabilities\n",
    "        produced by ``forward``.\"\"\"\n",
    "        return F.nll_loss(logits, labels)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return Adam(self.parameters(), lr=1e-3)\n",
    "\n",
    "    def prepare_data(self):\n",
    "        \"\"\"Download MNIST into the current working directory.\"\"\"\n",
    "        # Download only. No state is assigned here because Lightning calls\n",
    "        # this hook on a single process; the dataset object every process\n",
    "        # uses is created in setup().\n",
    "        MNIST(os.getcwd(), train=True, download=True)\n",
    "\n",
    "    def setup(self, stage=None):\n",
    "        \"\"\"Create the normalized training dataset from the downloaded files.\"\"\"\n",
    "        # transforms for images\n",
    "        transform = transforms.Compose(\n",
    "            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "        )\n",
    "        self.mnist_train = MNIST(\n",
    "            os.getcwd(), train=True, download=False, transform=transform\n",
    "        )\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        \"\"\"Return the native PyTorch loader over the training dataset.\"\"\"\n",
    "        # Use the notebook-wide BATCH_SIZE instead of a second hard-coded 64.\n",
    "        return DataLoader(\n",
    "            self.mnist_train,\n",
    "            batch_size=BATCH_SIZE,\n",
    "            num_workers=8,\n",
    "            pin_memory=True,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let us see how it works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n"
     ]
    }
   ],
   "source": [
    "model = LitMNIST()\n",
    "trainer = Trainer(accelerator=\"gpu\", devices=1, max_epochs=5)\n",
    "# DDP works only in non-interactive mode; to test it, uncomment the line below\n",
    "# and run this example as a script.\n",
    "# trainer = Trainer(devices=8, accelerator=\"gpu\", strategy=\"ddp\", max_epochs=5)\n",
    "## The MNIST data set is not always available for download due to network\n",
    "## issues, so the training call is disabled by default; uncomment the line\n",
    "## below to run this part of the example.\n",
    "# trainer.fit(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next step is to define a DALI pipeline that will be used for loading and pre-processing data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nvidia.dali as dali\n",
    "from nvidia.dali import pipeline_def\n",
    "import nvidia.dali.fn as fn\n",
    "import nvidia.dali.types as types\n",
    "from nvidia.dali.plugin.pytorch import (\n",
    "    DALIClassificationIterator,\n",
    "    LastBatchPolicy,\n",
    ")\n",
    "\n",
    "# Location of the MNIST training set inside the DALI_extra checkout\n",
    "data_path = os.path.join(os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/training/\")\n",
    "\n",
    "\n",
    "@pipeline_def\n",
    "def GetMnistPipeline(device, shard_id=0, num_shards=1):\n",
    "    \"\"\"Define the DALI graph that reads, decodes and normalizes MNIST samples.\n",
    "\n",
    "    With ``device == \"gpu\"`` the decoder runs on the \"mixed\" backend and the\n",
    "    labels are moved to the GPU; any other value decodes on the CPU.\n",
    "    ``shard_id`` and ``num_shards`` are forwarded to the reader.\n",
    "    Returns the ``(images, labels)`` outputs of the graph.\n",
    "    \"\"\"\n",
    "    use_gpu = device == \"gpu\"\n",
    "    encoded, targets = fn.readers.caffe2(\n",
    "        path=data_path,\n",
    "        shard_id=shard_id,\n",
    "        num_shards=num_shards,\n",
    "        random_shuffle=True,\n",
    "        name=\"Reader\",\n",
    "    )\n",
    "    decoded = fn.decoders.image(\n",
    "        encoded,\n",
    "        device=\"mixed\" if use_gpu else \"cpu\",\n",
    "        output_type=types.GRAY,\n",
    "    )\n",
    "    # Same mean/std constants as the torchvision Normalize used by LitMNIST,\n",
    "    # scaled by 255 (presumably because decoded pixels are in the 0-255 range).\n",
    "    normalized = fn.crop_mirror_normalize(\n",
    "        decoded,\n",
    "        dtype=types.FLOAT,\n",
    "        std=[0.3081 * 255],\n",
    "        mean=[0.1307 * 255],\n",
    "        output_layout=\"CHW\",\n",
    "    )\n",
    "    if use_gpu:\n",
    "        targets = targets.gpu()\n",
    "    # PyTorch expects labels as INT64\n",
    "    return normalized, fn.cast(targets, dtype=types.INT64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to modify the training class to use the DALI pipeline we have just defined. Because we want to integrate with PyTorch, we wrap our pipeline with a PyTorch DALI iterator, that can replace the native data loader with some minor changes in the code. The DALI iterator returns a list of dictionaries, where each element in the list corresponds to a pipeline instance, and the entries in the dictionary map to the outputs of the pipeline.\n",
    "\n",
    "For more information, check the documentation of DALIGenericIterator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DALILitMNIST(LitMNIST):\n",
    "    \"\"\"``LitMNIST`` variant whose training batches come from a DALI pipeline\n",
    "    wrapped in ``DALIClassificationIterator``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def prepare_data(self):\n",
    "        # no preparation is needed in DALI\n",
    "        pass\n",
    "\n",
    "    def setup(self, stage=None):\n",
    "        \"\"\"Build this process's DALI pipeline and the iterator around it.\"\"\"\n",
    "        device_id = self.local_rank\n",
    "        shard_id = self.global_rank\n",
    "        num_shards = self.trainer.world_size\n",
    "        mnist_pipeline = GetMnistPipeline(\n",
    "            batch_size=BATCH_SIZE,\n",
    "            device=\"gpu\",\n",
    "            device_id=device_id,\n",
    "            shard_id=shard_id,\n",
    "            num_shards=num_shards,\n",
    "            num_threads=8,\n",
    "        )\n",
    "        self.train_loader = DALIClassificationIterator(\n",
    "            mnist_pipeline,\n",
    "            reader_name=\"Reader\",\n",
    "            last_batch_policy=LastBatchPolicy.PARTIAL,\n",
    "            # The Trainer never calls reset() on the iterator, so let it\n",
    "            # rewind itself at the end of each epoch; otherwise it stays\n",
    "            # exhausted and every epoch after the first gets no batches.\n",
    "            auto_reset=True,\n",
    "        )\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return self.train_loader\n",
    "\n",
    "    def process_batch(self, batch):\n",
    "        \"\"\"Unpack the iterator output into an ``(images, labels)`` pair.\"\"\"\n",
    "        # The iterator yields one dict per pipeline; only one pipeline is\n",
    "        # used here, so take element 0.\n",
    "        x = batch[0][\"data\"]\n",
    "        # Drop only the trailing axis of the labels, keeping the batch axis.\n",
    "        y = batch[0][\"label\"].squeeze(-1)\n",
    "        return (x, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now run the training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name    | Type   | Params\n",
      "-----------------------------------\n",
      "0 | layer_1 | Linear | 100 K \n",
      "1 | layer_2 | Linear | 33.0 K\n",
      "2 | layer_3 | Linear | 2.6 K \n",
      "-----------------------------------\n",
      "136 K     Trainable params\n",
      "0         Non-trainable params\n",
      "136 K     Total params\n",
      "0.544     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a16d571e3bcb4db7aadfa06000859610",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# The previous Trainer keeps the GPU booked even after it has finished its\n",
    "# work, so drop PL_TRAINER_GPUS to force it to release the device.\n",
    "os.environ.pop(\"PL_TRAINER_GPUS\", None)\n",
    "model = DALILitMNIST()\n",
    "trainer = Trainer(\n",
    "    accelerator=\"gpu\", devices=1, max_epochs=5, num_sanity_val_steps=0\n",
    ")\n",
    "# DDP works only in non-interactive mode; to test it, uncomment the line below\n",
    "# and run this example as a script.\n",
    "# trainer = Trainer(devices=8, accelerator=\"gpu\", strategy=\"ddp\", max_epochs=5)\n",
    "trainer.fit(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For even better integration, we can provide a custom DALI iterator wrapper so that no extra processing is required inside `LitMNIST.process_batch`. Also, PyTorch can learn the size of the dataset this way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BetterDALILitMNIST(LitMNIST):\n",
    "    \"\"\"``LitMNIST`` variant fed by a DALI iterator that already yields\n",
    "    ``[images, labels]``, so the inherited ``process_batch`` works unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def prepare_data(self):\n",
    "        # no preparation is needed in DALI\n",
    "        pass\n",
    "\n",
    "    def setup(self, stage=None):\n",
    "        \"\"\"Build this process's DALI pipeline and the wrapped iterator.\"\"\"\n",
    "        device_id = self.local_rank\n",
    "        shard_id = self.global_rank\n",
    "        num_shards = self.trainer.world_size\n",
    "        mnist_pipeline = GetMnistPipeline(\n",
    "            batch_size=BATCH_SIZE,\n",
    "            device=\"gpu\",\n",
    "            device_id=device_id,\n",
    "            shard_id=shard_id,\n",
    "            num_shards=num_shards,\n",
    "            num_threads=8,\n",
    "        )\n",
    "\n",
    "        class LightningWrapper(DALIClassificationIterator):\n",
    "            \"\"\"Turns the iterator output into a plain ``[data, label]`` list.\"\"\"\n",
    "\n",
    "            def __next__(self):\n",
    "                # DDP is used, so there is only one pipeline per process and\n",
    "                # the single-element list can be unwrapped.\n",
    "                out = super().__next__()[0]\n",
    "                # Squeeze only the trailing label axis (as\n",
    "                # DALILitMNIST.process_batch does): a bare torch.squeeze\n",
    "                # would also drop the batch axis of a one-sample batch.\n",
    "                return [\n",
    "                    out[k].squeeze(-1) if k == \"label\" else out[k]\n",
    "                    for k in self.output_map\n",
    "                ]\n",
    "\n",
    "        self.train_loader = LightningWrapper(\n",
    "            mnist_pipeline,\n",
    "            reader_name=\"Reader\",\n",
    "            last_batch_policy=LastBatchPolicy.PARTIAL,\n",
    "            # The Trainer never calls reset() on the iterator, so let it\n",
    "            # rewind itself at the end of each epoch; otherwise it stays\n",
    "            # exhausted and every epoch after the first gets no batches.\n",
    "            auto_reset=True,\n",
    "        )\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return self.train_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us run the training one more time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name    | Type   | Params\n",
      "-----------------------------------\n",
      "0 | layer_1 | Linear | 100 K \n",
      "1 | layer_2 | Linear | 33.0 K\n",
      "2 | layer_3 | Linear | 2.6 K \n",
      "-----------------------------------\n",
      "136 K     Trainable params\n",
      "0         Non-trainable params\n",
      "136 K     Total params\n",
      "0.544     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d58bd060e32a44ae868a96cb67fcacc3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# The previous Trainer keeps the GPU booked even after it has finished its\n",
    "# work, so drop PL_TRAINER_GPUS to force it to release the device.\n",
    "os.environ.pop(\"PL_TRAINER_GPUS\", None)\n",
    "model = BetterDALILitMNIST()\n",
    "trainer = Trainer(\n",
    "    accelerator=\"gpu\", devices=1, max_epochs=5, num_sanity_val_steps=0\n",
    ")\n",
    "# DDP works only in non-interactive mode; to test it, uncomment the line below\n",
    "# and run this example as a script.\n",
    "# trainer = Trainer(devices=8, accelerator=\"gpu\", strategy=\"ddp\", max_epochs=5)\n",
    "trainer.fit(model)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
